Feat/r3 prod v2 #2593

Draft
samsja wants to merge 7 commits into feat/r3-prod from feat/r3-prod-v2

samsja commented May 22, 2026

Note

High Risk
High risk because it changes core RL training semantics (advantage computation/group finalization, loss scaling and gradient normalization) and expands trainer/orchestrator data contracts with new per-token exports and env_name plumbing, which can affect training stability and compatibility with existing custom advantage functions.

Overview
Changes advantage computation to be per-group and robust to rollout failures. AdvantageInputs.rollouts is now a single group (list[vf.RolloutOutput]) and AdvantageOutputs.advantages is now list[float]; compute_advantages now groups by (env_name, example_id) and supports partial groups when some rollouts error.

Updates rollout scheduling/failure handling to finalize partial groups instead of rescheduling by rounds. The scheduler now tracks failed rollouts per group, drops group-scoring groups on any failure, and otherwise proceeds once all dispatched rollouts have returned, sampling only the valid completions.
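
The finalization rule described above can be sketched as follows. This is a minimal illustration, not the PR's scheduler: `GroupState` here is a hypothetical stand-in, and the `dispatched`, `group_scoring`, and `completed` fields are assumed names (only `failed_rollouts` is mentioned in the commits).

```python
from dataclasses import dataclass, field


@dataclass
class GroupState:
    """Hypothetical stand-in for the scheduler's per-group bookkeeping."""

    dispatched: int                       # rollouts sent out for this example
    group_scoring: bool                   # True if scores depend on the whole group
    completed: list = field(default_factory=list)  # rollouts that returned successfully
    failed_rollouts: int = 0              # rollouts that returned with an error


def finalize(group: GroupState):
    """Return None while waiting, [] if the group is dropped, else the survivors."""
    returned = len(group.completed) + group.failed_rollouts
    if returned < group.dispatched:
        return None  # not every dispatched rollout has come back yet
    if group.group_scoring and group.failed_rollouts > 0:
        return []  # group-scoring envs drop the whole group on any failure
    return list(group.completed)  # individual scoring: keep only valid completions
```

Under these assumptions, a group of four with two successes and two failures finalizes with two rollouts for an individual-scoring env and with none for a group-scoring env.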

Adds trainer-side debugging/observability hooks. Introduces opt-in trainer.experimental.token_export to write per-token JSONL exports, plumbs env_name through TrainingSample/MicroBatch (plus per-token env_names and optional rewards), and logs entropy/mismatch_kl with {all,env} breakdowns plus new masking diagnostics.

Adjusts trainer loss normalization across distributed ranks. Loss scaling is now based on the global (dp_cp) count of unmasked tokens and gradients are rescaled to counteract FSDP’s per-rank averaging.

Reviewed by Cursor Bugbot for commit 540f42c. Bugbot is set up for automated code reviews on this repo.

samsja and others added 7 commits May 22, 2026 01:46
…ut error (#2590)

* feat(scheduler): train on partial groups instead of dropping on rollout error

When an individual-scoring env returns N rollouts for a single example and one
errors out, scrap only the failed rollout — keep the survivors and ship the
group through as soon as every dispatched rollout has come back (success or
failure). Group-scoring envs still drop the whole group on any failure because
their per-rollout scores are computed against the now-missing rollouts.

To make variable-size groups round-trip through advantage computation, group
rollouts by (env_name, example_id) instead of positional slicing, and bucket
groups by size so each advantage_fn call still sees a uniform 2D rewards
tensor. Singleton groups produce zero advantage and get filtered out by the
existing zero-advantage filter — no special-casing.
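
A minimal sketch of key-based grouping, assuming each rollout is a dict carrying `env_name` and `example_id` fields (the real `RolloutOutput` layout is not shown in this PR, so these field names are illustrative). It covers only the grouping step, not the bucket-by-size step this commit also describes.

```python
from collections import defaultdict


def group_rollouts(rollouts):
    """Group rollouts by (env_name, example_id) rather than by list position."""
    groups = defaultdict(list)
    for rollout in rollouts:
        key = (rollout["env_name"], rollout["example_id"])
        groups[key].append(rollout)
    return dict(groups)


# Example: one rollout of ("math", 0) failed upstream and is simply absent,
# so the two groups end up with different sizes.
rollouts = [
    {"env_name": "math", "example_id": 0, "reward": 1.0},
    {"env_name": "math", "example_id": 1, "reward": 0.0},
    {"env_name": "math", "example_id": 1, "reward": 1.0},
]
sizes = {key: len(group) for key, group in group_rollouts(rollouts).items()}
```

Positional slicing would assume a fixed group size and misalign every group after the first missing rollout; keying on the pair avoids that. If the advantage is reward minus the group mean (an assumption here), a singleton group's advantage is its reward minus itself, which is zero.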

Closes #2585.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* refactor(advantage): make advantage_fn per-group

Drop the bucket-by-size workaround in compute_advantages by changing the
advantage_fn contract: AdvantageInputs.rollouts is now a single group
(list[RolloutOutput]) and AdvantageOutputs.advantages is 1D. The framework
calls advantage_fn once per group, which works cleanly for variable-size
groups (partial-group training).

BREAKING: second change to this public API in three weeks. Custom advantage
functions must drop the outer list dim. Migration documented in CHANGELOG.md
and docs/bring-your-own-algorithms.md.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* refactor(advantage): return list[float] from advantage_fn

Drops the torch tensor from the public AdvantageOutputs contract; internal
math stays in torch and converts via .tolist() at the boundary. Same partial-
group support, simpler downstream consumers (no more .tolist() / no shape
gymnastics in custom advantages).

BREAKING (folds into the per-group change in the previous commit): custom
advantage functions must return AdvantageOutputs(advantages=[...]) rather
than a tensor. CHANGELOG entry and docs example updated.
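
As a rough illustration of the new contract, a custom advantage function might look like the sketch below. The two dataclasses are simplified stand-ins for the real `AdvantageInputs` / `AdvantageOutputs` types, the `reward` key is an assumed field, and the mean baseline is just an example choice, not the framework's default.

```python
from dataclasses import dataclass


@dataclass
class AdvantageInputs:
    """Stand-in: one group of rollouts (dicts here instead of RolloutOutput)."""

    rollouts: list[dict]


@dataclass
class AdvantageOutputs:
    """Stand-in: one advantage per rollout in the group, as plain floats."""

    advantages: list[float]


def mean_baseline_advantage(inputs: AdvantageInputs) -> AdvantageOutputs:
    """Called once per group; works for any group size, including partial groups."""
    rewards = [float(r["reward"]) for r in inputs.rollouts]
    baseline = sum(rewards) / len(rewards)
    return AdvantageOutputs(advantages=[r - baseline for r in rewards])


out = mean_baseline_advantage(
    AdvantageInputs(rollouts=[{"reward": 1.0}, {"reward": 0.0}, {"reward": 2.0}])
)
```

Compared with the previous batched contract, there is no outer list dimension and no tensor to convert at the call site.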

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* refactor(scheduler): log each rollout failure inline, drop GroupState.last_failure_reason

The reason is now logged at the moment the failure is observed (one warning
per failed rollout) instead of being stashed on the group and replayed at
finalization. Removes the per-group field entirely and avoids the "first vs
latest wins" semantic question that came up in review - each log line carries
its own actual reason. Finalization warnings only carry counts now.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore(scheduler): drop verbose comment on GroupState.failed_rollouts

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…2589)

* feat(orchestrator): per-env state_columns for extra rollout fields

Adds `state_columns: list[str] = []` to `EnvConfig` so each env can
persist additional `State` fields into the saved JSONL rollouts on top
of the always-saved `trajectory` and `sampling_args`. Merged at the
call site (required first, deduped).
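
One way to express "required first, deduped" is sketched below. The constant and function names are hypothetical, and this is not necessarily how the call site in the PR does it.

```python
# Columns the commit describes as always saved.
REQUIRED_STATE_COLUMNS = ["trajectory", "sampling_args"]


def merge_state_columns(extra: list[str]) -> list[str]:
    """Required columns first, then per-env extras, with duplicates removed."""
    # dict.fromkeys keeps the first occurrence of each key in insertion order.
    return list(dict.fromkeys(REQUIRED_STATE_COLUMNS + extra))


merged = merge_state_columns(["tool_calls", "trajectory"])
```

Here `merged` keeps `trajectory` and `sampling_args` at the front and appends `tool_calls` once, even though `trajectory` was also listed by the env.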

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* refactor: drop seen set from state_columns dedup

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Each DP rank was dividing its summed token loss by its own local
loss_scale, then FSDP averaged the resulting gradients. Because ranks
process different sequence lengths, that mean is not the true per-token
mean over the global batch — ranks with fewer loss tokens get implicitly
upweighted.

Mirror the SFT trainer fix (src/prime_rl/trainer/sft/train.py:416-427):
all-reduce the local token count across dp_cp, divide by that global
denominator on every rank, and multiply grads by fsdp_gradient_divide_factor
after the microbatch loop so FSDP's per-rank averaging is undone and the
final gradient is the per-token mean over the global batch.
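
The bias and the fix can be checked with a toy calculation. The sketch below uses scalar losses as a stand-in for gradients (the gradient is linear in the loss, so the same weighting applies), replaces the all-reduce with a plain sum over lists, and assumes `fsdp_gradient_divide_factor` equals the number of ranks.

```python
def per_rank_mean_then_average(per_rank_losses):
    """Old behaviour: each rank divides by its local token count, FSDP averages."""
    local_means = [sum(losses) / len(losses) for losses in per_rank_losses]
    return sum(local_means) / len(local_means)


def global_token_mean(per_rank_losses):
    """New behaviour: global denominator, then undo FSDP's per-rank averaging."""
    world_size = len(per_rank_losses)
    # Stand-in for all-reducing the local token count across dp_cp.
    global_tokens = sum(len(losses) for losses in per_rank_losses)
    scaled = [sum(losses) / global_tokens for losses in per_rank_losses]
    fsdp_averaged = sum(scaled) / world_size   # what FSDP's reduction produces
    return fsdp_averaged * world_size          # rescale by the divide factor


# Rank 0 has three loss tokens, rank 1 has one.
losses = [[1.0, 1.0, 1.0], [5.0]]
biased = per_rank_mean_then_average(losses)    # 3.0: rank 1's token is upweighted
unbiased = global_token_mean(losses)           # 2.0: true mean over all 4 tokens
```

With equal token counts on every rank the two functions agree, which is why the bias only shows up when sequence lengths differ across ranks.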

Closes #2358. Adapted from #2359, which first diagnosed the bias and
proposed the all-reduce-then-rescale approach.

Co-authored-by: irfanjamil <irfanjamil9@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* add per-env trainer metrics

* require env names for trainer batches

* address per-env metric review feedback

* reuse tensor stats for env metrics

* fix sft trajectory env-name fixture

* address trainer metric naming comments

* fix: reuse trainer ratio tensors for env metrics

* fix: derive dppo mask from shared ratio

* address PR review: drop precomputed loss inputs, use {all,env} keys

- compute importance-ratio / mismatch_kl inside the loss functions instead of
  passing them in via LossInputs (per Mika)
- compute mismatch_kl inline in train.py only for per-env logging
- rename trainer aggregate keys to entropy/all and mismatch_kl/all to match the
  orchestrator {all,env}/{mean,std,max} convention; drop the leftover bare
  entropy/* and mismatch_kl/* keys
- drop the overly defensive env_names length check in DataLoader

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* guard mismatch_kl logging behind sft_loss flag

SFT batches don't have meaningful inference_logprobs (sft_loss_fn ignores
them), so computing and logging mismatch_kl for those microbatches is wasted
work and produces misleading numbers. Skip the inline mismatch_kl compute and
the per-env / debug-log emissions when sft_loss is True.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* address review: simplify probs_diff and move mismatch_kl/all to trainer loop

* reserve env_name='all' for aggregate metric keys
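
A sketch of why `all` has to be reserved, assuming metric keys take the form `<metric>/<all|env>/<mean|std|max>`. The exact key layout and the helper below are illustrative guesses based on the commit messages above, not the trainer's code.

```python
import statistics

RESERVED_ENV_NAME = "all"  # scope name used for the cross-env aggregate


def env_metrics(metric: str, values_by_env: dict[str, list[float]]) -> dict[str, float]:
    """Build aggregate and per-env mean/std/max keys for one metric."""
    if RESERVED_ENV_NAME in values_by_env:
        # An env literally named "all" would collide with the aggregate keys.
        raise ValueError(f"env_name {RESERVED_ENV_NAME!r} is reserved")
    pooled = [v for values in values_by_env.values() for v in values]
    out = {}
    for scope, values in [(RESERVED_ENV_NAME, pooled), *values_by_env.items()]:
        out[f"{metric}/{scope}/mean"] = statistics.fmean(values)
        out[f"{metric}/{scope}/std"] = statistics.pstdev(values)  # population std, an assumption
        out[f"{metric}/{scope}/max"] = max(values)
    return out


metrics = env_metrics("entropy", {"math": [1.0, 3.0], "code": [2.0]})
```

The result contains keys such as `entropy/all/mean` and `entropy/math/max`; passing an env named `all` raises instead of silently overwriting the aggregate.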

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* feat: add trainer token jsonl export

* chore: use docstrings for token export config
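
For readers unfamiliar with the format, a per-token JSONL export could look roughly like this. The record fields, the one-record-per-token layout, and the function name are assumptions for illustration; the schema written by `trainer.experimental.token_export` is not shown in this PR description.

```python
import io
import json


def export_tokens(fh, token_ids, logprobs, loss_mask, env_name):
    """Write one JSON object per token, one object per line (JSONL)."""
    for position, (token_id, logprob, masked_in) in enumerate(
        zip(token_ids, logprobs, loss_mask)
    ):
        record = {
            "position": position,
            "token_id": token_id,
            "logprob": logprob,
            "loss_mask": masked_in,
            "env_name": env_name,  # repeated per token so lines are self-contained
        }
        fh.write(json.dumps(record) + "\n")


buffer = io.StringIO()  # a real export would write to a file on disk
export_tokens(buffer, [5, 7], [-0.1, -0.2], [True, False], "math")
lines = buffer.getvalue().splitlines()
```

Because every line is an independent JSON object, the output can be streamed, appended to across steps, and filtered with line-oriented tools.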
samsja marked this pull request as ready for review May 22, 2026 20:17
samsja marked this pull request as draft May 22, 2026 20:17
2 participants